#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division, with_statement
'''
Copyright 2015, 陈同 (chentong_biology@163.com).  
===========================================================
'''
__author__ = 'chentong & ct586[9]'
__author_email__ = 'chentong_biology@163.com'
#=========================================================
desc = '''
Program description:
    This is designed to summarize reads distribution output by `bam_stat.py`.
'''

import sys
import os
from json import dumps as json_dumps
from time import localtime, strftime 
timeformat = "%Y-%m-%d %H:%M:%S"
from optparse import OptionParser as OP
import re
from tools import *
#from multiprocessing.dummy import Pool as ThreadPool

#from bs4 import BeautifulSoup
reload(sys)
sys.setdefaultencoding('utf8')

debug = 0

def fprint(content):
    """Pretty-print *content* to stdout as JSON.

    Args:
        content: any JSON-serialisable object (dict, list, str, ...).

    Raises:
        TypeError: if *content* is not JSON-serialisable.
    """
    print json_dumps(content,indent=1)

def cmdparameter(argv):
    """Parse the command line.

    Args:
        argv (list): ``sys.argv``-style argument vector.

    Returns:
        tuple: ``(options, args)`` as produced by optparse.

    Raises:
        AssertionError: if any of the required -f/-l/-F options is missing.
        SystemExit: when called with no arguments (prints desc + help).
    """
    if len(argv) == 1:
        # No arguments: show the program description, then the
        # auto-generated optparse help by re-invoking ourselves with -h.
        global desc
        print >>sys.stderr, desc
        cmd = 'python ' + argv[0] + ' -h'
        os.system(cmd)
        sys.exit(1)
    usages = "%prog -f file"
    parser = OP(usage=usages)
    parser.add_option("-f", "--files", dest="filein",
        metavar="FILEIN", help="`,` or ` ` separated a list of files. Generated by `bam_stats.py` for initial BAM generated by bwa-mem")
    parser.add_option("-l", "--labels", dest="label",
        metavar="LABEL", help="`,` or ` ` separated a list of labels to label each file. It must have same order as files.")
    parser.add_option("-F", "--files2", dest="filein2",
        metavar="FILEIN", help="`,` or ` ` separated a list of files. Generate by `bam_stats.py` for post-processed BAM for downstream usages.")
    parser.add_option("-o", "--output-prefix", dest="out_prefix",
        help="The prefix of output files.")
    parser.add_option("-r", "--report-dir", dest="report_dir",
        default='report', help="Directory for report files. Default 'report'.")
    parser.add_option("-R", "--report-sub-dir", dest="report_sub_dir",
        default='2_mapping_quality', help="Directory for saving report figures and tables. This dir will put under <report_dir>,  so only dir name is needed. Default '2_mapping_quality'.")
    parser.add_option("-d", "--doc-only", dest="doc_only",
        default=False, action="store_true", help="Specify to only generate doc.")
    parser.add_option("-v", "--verbose", dest="verbose",
        action="store_true", help="Show process information")
    parser.add_option("-D", "--debug", dest="debug",
        default=False, action="store_true", help="Debug the program")
    (options, args) = parser.parse_args(argv[1:])
    # -f, -l and -F are all consumed unconditionally in main(); fail early
    # with a clear message instead of an AttributeError on None later.
    # (Original message wrongly referred to "-i"; the option is -f.)
    assert options.filein != None, "A filename needed for -f"
    assert options.label != None, "A label list needed for -l"
    assert options.filein2 != None, "A filename needed for -F"
    return (options, args)
#--------------------------------------------------------------------
#--------------------------------------------------------------------

def readFile(file):
    """Parse one `bam_stat.py` output file into a {metric: count} dict.

    Two line formats are recognised:
      * ``name: value``  -- split at the first colon (name NOT stripped,
        matching the keys looked up in main(), e.g. 'Total records').
      * ``name value``   -- split at the last space; name is stripped.
    Blank lines and '#' comment lines without a colon are skipped.
    NOTE(review): a '#' comment line that contains a colon is still
    parsed as data -- presumably such lines never occur in the input.

    Args:
        file (str): path of the statistics file to read.

    Returns:
        dict: metric name (str) -> count (int).

    Raises:
        ValueError: if a value field is not an integer.
    """
    sumD = {}
    # `with` guarantees the handle is closed (the original code leaked it;
    # the file already imports `with_statement` from __future__).
    with open(file) as fh:
        for line in fh:
            colon = line.find(':')
            if colon != -1:
                key = line[:colon]
                sumD[key] = int(line[colon+1:])
            elif line.strip() and (not line.startswith('#')):
                key, value = line.rsplit(' ', 1)
                sumD[key.strip()] = int(value)
    if debug:
        # Equivalent to `print >>sys.stderr, sumD`, but version-neutral.
        sys.stderr.write(str(sumD) + '\n')

    return sumD
#--------------------------------------------------------------------
#--------------------------------------------------------------------
def plot(fileL):
    """Render a bar plot for every summary table via the external
    `s-plot` tool (one `s-plot barPlot` shell invocation per file)."""
    for summary_file in fileL:
        os.system("s-plot barPlot -f " + summary_file)
#--------------------------------------
def plot_melt(total_melt, nameL):
    """Draw stacked and percent-filled bar plots of the melted summary
    table via the external `s-plot` tool.

    Args:
        total_melt (str): path of the melted summary table.
        nameL (list): sample labels, in plotting order (x-axis levels).
    """
    # Build the -L level string: "'s1','s2',..." wrapped in double quotes.
    quoted_levels = '"' + ','.join("'" + name + "'" for name in nameL) + '"'
    template = ("s-plot barPlot -m TRUE -a Sample -R 45 -B set -O 1%s -w 25 -u 30 -f "
                "%s -y  '%s' -x 'Samples' -L %s")
    # Absolute counts (stacked bars), then relative percent (-d fill).
    os.system(template % ('', total_melt, 'Reads count', quoted_levels))
    os.system(template % (' -d fill', total_melt, 'Reads percent', quoted_levels))
#--------------------------------------
#--------------------------------------
def generateDoc(report_dir, report_sub_dir, totalTable, total_melt, curation_label):
    """Copy the result table and figures into the report directory and
    print the R-markdown (bookdown/knitr) section for reads-mapping
    statistics to stdout.

    Args:
        report_dir (str): top-level report directory.
        report_sub_dir (str): sub-directory (under report_dir) for this
            section's tables and figures.
        totalTable (str): path of the map-summary table (xls).
        total_melt (str): path of the melted summary table;
            <total_melt>.stackBars.pdf / .fillBars.pdf must already exist.
        curation_label (str): NOTE(review): effectively ignored -- it is
            overwritten with a hard-coded value below; probable oversight.
    """
    dest_dir = report_dir+'/'+report_sub_dir+'/'
    os.system('mkdir -p '+dest_dir)
    pdf_stack = total_melt+'.stackBars.pdf'
    pdf_fill = total_melt+'.fillBars.pdf'
    # copy/copypdf presumably come from `tools` (star import) -- copy the
    # table and both PDFs (converted/renamed as needed) into dest_dir.
    copy(dest_dir, totalTable)
    copypdf(dest_dir, pdf_stack, pdf_fill)

    print "\n## Reads mapping statistics\n"
    curation_label = "Reads_mapping_statistics"
    knitr_read_txt(report_dir,  curation_label)
    
    # Fixed Rmd chunk: table explaining each summary-column symbol.
    print """每个样品测序序列比对总结见 (Table \@ref(tab:read-map-sum-table), Figure \@ref(fig:read-map-sum-fig) and Figure \@ref(fig:read-map-sum-percent-fig))。

```{r read-map-sum-table-explain, results="asis"}
read_map_sum_table_explain='Symbol;Explanation;In_figure
**Total**;预处理后用于比对的中的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Reads usage"子图
**Map**;比对到基因组的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Map attribute"子图
**UnMap**;未比对到基因组的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Map attribute"子图
**Unique_map**;在基因组上有唯一比对位置的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Map attribute"子图
**Multi_map**;在基因组上有多个比对位置的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Map attribute"子图
**Proper_pair**;双端reads都比对正确的reads数;
**Read_1_map**;左端reads比对回基因组的数目;Figure \\\\@ref(fig:read-map-sum-fig)的"Each end map"子图
**Read_2_map**;右端reads比对回基因组的数目;Figure \\\\@ref(fig:read-map-sum-fig)的"Each end map"子图
**Read_pos_map**;比对到基因组正链的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Each strand map"子图
**Read_neg_map**;比对到基因组负链的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Each strand map"子图
**Final_kept_reads**;过滤后用于下游分析的reads数;Figure \\\\@ref(fig:read-map-sum-fig)的"Reads usage"子图'
read_map_sum_table_explain_data <- read.table(text=read_map_sum_table_explain, sep=';', quote="", header=T)
knitr::kable(read_map_sum_table_explain_data, format="markdown")
```
"""

    print "Reads比对数据的统计可以帮助判断测序的质量、准确度、有无偏好和异常等。通常来讲左端reads的比对数和右端reads的比对数应该相当；对于DNA测序来说，比对到正链和负链的reads应该相当。如果过滤后用于下游分析的reads数偏少，则需要慎重考虑。\n"

    print "原始数据或者PDF格式的文件可以点击XLS或PDF下载。\n"
    
    # Paths in the Rmd must be relative to report_dir, i.e. sub_dir/basename.
    totalTableNew = report_sub_dir+'/'+os.path.split(totalTable)[-1]
    print "(ref:read-map-sum-table) Summary of mapped reads. [XLS]({})\n".format(totalTableNew)
    
    print '''```{{r read-map-sum-table}}
map_table <- read.table("{totalTableNew}", sep="\\t", header=T, row.names=1, quote="", comment="")
knitr::kable(map_table, booktabs=T, caption="(ref:read-map-sum-table)", format.args = list(big.mark = ','))
```
'''.format(totalTableNew=totalTableNew)
    
    # The stacked-bars figure: Rmd links the PDF, embeds the PNG sibling.
    pdf_stack = report_sub_dir+'/'+os.path.split(pdf_stack)[-1]
    png_stack = pdf_stack.replace('pdf', 'png')
    
    print "(ref:read-map-sum-fig) Summary of mapped reads. 1 Million = 10^6^. \
*Each end map* 指左端reads和右端reads各自比对回基因组的数目；\
*Each strand map* 指reads分别比对到正链和负链的数目；\
*Map attribute* 指不同比对类型reads的数目比较；\
*Reads usage* 指过滤后可用reads和弃掉reads的数目。[PDF]({})\n".format(pdf_stack)

    print '''```{{r read-map-sum-fig, fig.cap="(ref:read-map-sum-fig)"}}
knitr::include_graphics("{png}")
```
'''.format(png=png_stack)

    # Same pattern for the percent-filled figure.
    pdf_fill = report_sub_dir+'/'+os.path.split(pdf_fill)[-1]
    png_fill = pdf_fill.replace('pdf', 'png')
    
    print "(ref:read-map-sum-percent-fig) Summary of mapped reads (relative percent).\
*Each end map* 指左端reads和右端reads各自比对回基因组的比例；\
*Each strand map* 指reads分别比对到正链和负链的比例；\
*Map attribute* 指不同比对类型reads的数目比较；\
*Reads usage* 指过滤后可用reads和弃掉reads的比例。[PDF]({})\n".format(pdf_fill)

    print '''```{{r read-map-sum-percent-fig, fig.cap="(ref:read-map-sum-percent-fig)"}}
knitr::include_graphics("{png}")
```
'''.format(png=png_fill)
#--------------------------------


def main():
    """Driver: read per-sample `bam_stat.py` outputs (raw and
    post-processed BAM), write the summary / melted / per-aspect tables
    plus a `norm_factor` file, plot the melted table, then print the
    report section.
    """
    options, args = cmdparameter(sys.argv)
    #-----------------------------------
    file = options.filein
    # NOTE(review): '[, ]*' can match the empty string; this relies on
    # Python 2 re.split ignoring zero-width matches. Under Python 3.7+
    # the same pattern would split between every character -- use
    # r'[, ]+' if this script is ever ported.
    fileL = re.split(r'[, ]*', file.strip())
    label = options.label
    labelL = re.split(r'[, ]*', label.strip())
    postFile = options.filein2
    postFileL = re.split(r'[, ]*', postFile.strip())
    verbose = options.verbose
    op = options.out_prefix
    report_dir = options.report_dir
    report_sub_dir = options.report_sub_dir
    global debug
    debug = options.debug
    doc_only = options.doc_only
    #-----------------------------------
    aDict = {}
    totalTable = op+".map_summary.xls"
    total_melt = op+'.melt_summary.xls'
    # e.g. 'script.py' -> 'script_py' (dots break knitr labels).
    curation_label = os.path.split(sys.argv[0])[-1].replace('.', '_')
    if doc_only:
        # Tables and plots already exist; only regenerate the report text.
        generateDoc(report_dir, report_sub_dir, totalTable, total_melt, curation_label)
        return 0


    # One wide summary table, plus one narrow table per plotting aspect.
    totalTable_fh = open(totalTable, 'w')
    print >>totalTable_fh, "Sample\tTotal\tMap\tUnMap\tUnique_map\tMulti_map\tProper_pair\tRead_1_map\tRead_2_map\tRead_pos_map\tRead_neg_map\tFinal_kept_reads"

    un_unique_multi = op+'.un_unique_multi.xls'
    un_unique_multi_fh = open(un_unique_multi, 'w')
    print >>un_unique_multi_fh, "Sample\tUnMap\tUnique_map\tMulti_map"

    read_1_2 = op+'.read_1_2_map.xls'
    read_1_2_fh = open(read_1_2, 'w')
    print >>read_1_2_fh, "Sample\tRead_1_map\tRead_2_map"
    
    read_pos_neg = op+'.read_pos_neg_map.xls'
    read_pos_neg_fh = open(read_pos_neg, 'w')
    print >>read_pos_neg_fh, "Sample\tRead_pos_map\tRead_neg_map"

    read_used_unused = op+'.read_used_unused.xls'
    read_used_unused_fh = open(read_used_unused, 'w')
    print >>read_used_unused_fh, "Sample\tKept_reads\tDiscarded_reads"

    # Long-format table consumed by plot_melt (s-plot -m TRUE).
    total_melt_fh = open(total_melt, 'w')

    print >>total_melt_fh, "Sample\tvariable\tvalue\tset"
    # Final kept-read counts per sample, written to fixed file name
    # 'norm_factor' in the working directory (for downstream scaling).
    norm_factor = open("norm_factor", 'w')
    i = -1
    for file in fileL:
        i += 1
        label = labelL[i]
        # Raw-BAM stats and post-processed-BAM stats, same order as labels.
        sumD = readFile(file)
        sum_postD = readFile(postFileL[i])
        # Primary records only.
        total = sumD['Total records'] - sumD['Non primary hits']
        UnMap = sumD['Unmapped reads']
        Unique_map = sumD['mapq >= mapq_cut (unique)']
        Multi_map = sumD['mapq < mapq_cut (non-unique)']
        # Sanity check: the three categories must partition the total.
        assert total == UnMap+Unique_map+Multi_map, sumD
        Proper_pair = sumD['Reads mapped in proper pairs']
        Read_1 = sumD['Read-1']
        # 'Read-2' may be absent (single-end data): default to 0.
        Read_2 = sumD.get('Read-2', 0)
        Read_pos = sumD["Reads map to '+'"]
        Read_neg = sumD["Reads map to '-'"]
        Final_used = sum_postD['Total records']
        
        print >>norm_factor, "{}\t{}".format(label, Final_used)
        lineL = [label, total, Unique_map+Multi_map, UnMap, 
            Unique_map, Multi_map, Proper_pair, Read_1, Read_2, 
            Read_pos, Read_neg, Final_used]

        lineL = [str(j) for j in lineL]

        print >>totalTable_fh, '\t'.join(lineL)

        # Melted rows, grouped by the 'set' column into plot facets.
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'UnMap', UnMap, "Map attribute")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Unique_map', Unique_map, "Map attribute")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Multi_map', Multi_map, "Map attribute")

        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Read_1', Read_1, "Each end map")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Read_2', Read_2, "Each end map")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Read_pos', Read_pos, "Each strand map")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Read_neg', Read_neg, "Each strand map")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Kept_reads', Final_used, "Reads usage")
        print >>total_melt_fh, "{}\t{}\t{}\t{}".format(label,'Discarded_reads', total-Final_used, "Reads usage")

        # lineL[3:6] == [UnMap, Unique_map, Multi_map] (already strings).
        un_unique_multiL = [label]
        un_unique_multiL.extend(lineL[3:6])
        print >>un_unique_multi_fh, '\t'.join(un_unique_multiL)
        
        print >>read_1_2_fh, '{}\t{}\t{}'.format(label, Read_1, Read_2)

        print >>read_pos_neg_fh,'{}\t{}\t{}'.format(label, Read_pos, Read_neg) 

        print >>read_used_unused_fh, '{}\t{}\t{}'.format(label, Final_used, total-Final_used) 

    norm_factor.close()
    totalTable_fh.close()
    un_unique_multi_fh.close()
    read_1_2_fh.close()
    read_pos_neg_fh.close()
    read_used_unused_fh.close()
    total_melt_fh.close()

    #plot([un_unique_multi, read_1_2, read_pos_neg, read_used_unused])
    plot_melt(total_melt, labelL)

    generateDoc(report_dir, report_sub_dir, totalTable, total_melt, curation_label)
    ###--------multi-process------------------

if __name__ == '__main__':
    # Record wall-clock start/end times around main() and append a run
    # record (command line + duration) to 'python.log' in the CWD.
    startTime = strftime(timeformat, localtime())
    main()
    endTime = strftime(timeformat, localtime())
    fh = open('python.log', 'a')
    print >>fh, "%s\n\tRun time : %s - %s " % \
        (' '.join(sys.argv), startTime, endTime)
    fh.close()
    ###---------profile the program---------
    #import profile
    #profile_output = sys.argv[0]+".prof.txt")
    #profile.run("main()", profile_output)
    #import pstats
    #p = pstats.Stats(profile_output)
    #p.sort_stats("time").print_stats()
    ###---------profile the program---------


